import seq_cifar100
import torchvision.transforms as transforms
from datasets.utils.validation import get_train_val
from utils.conf import base_path_dataset as base_path
from torch.utils.data import DataLoader, Subset
import numpy as np
import torch


class HetroCIFAR100(seq_cifar100.SequentialCIFAR100):
    """
    CIFAR-100 split into tasks that hold different numbers of classes.

    Every call to `get_data_loaders` serves the task indexed by `task_num`;
    the number of classes of task `t` is `task_class_nums[t]`.
    """

    NAME = 'hetro-cifar100'
    SETTING = 'class-il'
    N_CLASSES_PER_TASK = 100
    N_TASKS = 20

    def __init__(self, args):
        """
        Initialise the parent dataset and the per-task bookkeeping.

        :param args: arguments forwarded to the parent constructor
        """
        # Set before delegating to the parent constructor.
        self.num_classes = 100
        super().__init__(args=args)

        # TODO: find a better way to generate heterogeneous tasks
        # (heterogeneous in the number of classes per task), e.g.
        # np.random.choice(8, 30, p=[0.2, 0.15, 0.1, 0.1, 0.1, 0.1, 0.1, 0.15]) + 2
        # Number of classes in each of the 20 tasks (the entries sum to 100).
        self.task_class_nums = [
            9, 2, 7, 3, 4, 9, 8, 3, 3, 7,
            4, 4, 5, 9, 4, 5, 2, 8, 2, 2,
        ]
        # Index into `task_class_nums` of the task served next.
        self.task_num = 0

    def get_data_loaders(self):
        """
        Create the train and test loaders for the current task.

        :return: the `(train_loader, test_loader)` pair produced by `get_hetro_split`
        """
        train_transform = self.TRANSFORM
        eval_transform = transforms.Compose([
            transforms.ToTensor(),
            self.get_normalization_transform(),
        ])

        train_dataset = seq_cifar100.MyCIFAR100(
            base_path() + 'CIFAR100', train=True, download=True,
            transform=train_transform)

        if not self.args.validation:
            test_dataset = seq_cifar100.TCIFAR100(
                base_path() + 'CIFAR100', train=False, download=True,
                transform=eval_transform)
        else:
            # Validation mode: both datasets are obtained from `get_train_val`
            # applied to the training set.
            train_dataset, test_dataset = get_train_val(
                train_dataset, eval_transform, self.NAME)

        return get_hetro_split(self, train_dataset, test_dataset)


def get_hetro_split(setting, train_dataset, test_dataset):
    """
    Restrict both datasets to the classes of the current task and build their loaders.

    The current task covers the labels in the half-open range
    `[setting.i, setting.i + setting.task_class_nums[setting.task_num])`.
    Both datasets are filtered in place (their `data` and `targets` attributes
    are replaced). The test loader is appended to `setting.test_loaders`, the
    train loader is stored in `setting.train_loader`, and `setting.i` and
    `setting.task_num` are advanced to the next task.

    :param setting: object holding `i`, `task_num`, `task_class_nums`,
        `test_loaders` and `args.batch_size`
    :param train_dataset: train dataset exposing `data` and `targets`
    :param test_dataset: test dataset exposing `data` and `targets`
    :return: train and test loaders for the current task
    """
    # Convert the labels once; they are reused for masking and re-assignment.
    train_labels = np.array(train_dataset.targets)
    test_labels = np.array(test_dataset.targets)

    n_task_classes = setting.task_class_nums[setting.task_num]
    first_class = setting.i
    end_class = first_class + n_task_classes

    # Positions of the samples whose label belongs to the current task.
    train_keep = np.nonzero((train_labels >= first_class) & (train_labels < end_class))[0]
    test_keep = np.nonzero((test_labels >= first_class) & (test_labels < end_class))[0]

    train_dataset.data = train_dataset.data[train_keep]
    train_dataset.targets = train_labels[train_keep]
    test_dataset.data = test_dataset.data[test_keep]
    test_dataset.targets = test_labels[test_keep]

    batch_size = setting.args.batch_size
    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              shuffle=True, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=batch_size,
                             shuffle=False, num_workers=4)

    # Record the loaders and move the bookkeeping on to the next task.
    setting.test_loaders.append(test_loader)
    setting.train_loader = train_loader
    setting.i = first_class + n_task_classes
    setting.task_num += 1
    return train_loader, test_loader

